{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "77be7dda",
   "metadata": {},
   "source": [
    "## vLLM-LMI Mixtral-8x7B-DPO-AWQ deployment guide"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d69a078d",
   "metadata": {},
   "source": [
    "### In this tutorial, you will use vllm backend of Large Model Inference(LMI) DLC to deploy Mixtral-8x7B-DPO-AWQ and run inference with it.\n",
    "\n",
    "Please make sure the following permission granted before running the notebook:\n",
    "\n",
    "* S3 bucket push access\n",
    "* SageMaker access\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7bc171a0",
   "metadata": {},
   "source": [
    "### Step 1: Let's bump up SageMaker and import stuff"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "bdd49db4",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Note: you may need to restart the kernel to use updated packages.\n"
     ]
    }
   ],
   "source": [
    "%pip install sagemaker --upgrade  --quiet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "271ce0b4-e9b2-4661-9767-7580ff6f86be",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "%pip install transformers sentencepiece --upgrade  --quiet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "cc3c2465",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import boto3\n",
    "import sagemaker\n",
    "from sagemaker import Model, image_uris, serializers, deserializers\n",
    "\n",
    "role = sagemaker.get_execution_role()  # execution role for the endpoint\n",
    "sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs\n",
    "region = sess._region_name  # region name of the current SageMaker Studio environment\n",
    "account_id = sess.account_id()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ab26348",
   "metadata": {},
   "source": [
    "### Step 2: Start preparing model artifacts"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86df559e",
   "metadata": {},
   "source": [
    "In LMI container, we expect some artifacts to help setting up the model\n",
    "\n",
    "* serving.properties (required): Defines the model server settings\n",
    "* model.py (optional): A python file to define the core inference logic\n",
    "* requirements.txt (optional): Any additional pip wheel need to install"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "19ff0dc9",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing serving.properties\n"
     ]
    }
   ],
   "source": [
    "%%writefile serving.properties\n",
    "engine=Python\n",
    "option.model_id=TheBloke/Nous-Hermes-2-Mixtral-8x7B-DPO-AWQ\n",
    "option.tensor_parallel_degree=4\n",
    "option.max_rolling_batch_size=64\n",
    "option.rolling_batch=vllm\n",
    "option.task=text-generation\n",
    "option.dtype=fp16\n",
    "option.quantize=awq\n",
    "option.max_model_len=8192"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "79c21521",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mymodel/\n",
      "mymodel/serving.properties\n"
     ]
    }
   ],
   "source": [
    "%%sh\n",
    "mkdir mymodel\n",
    "mv serving.properties mymodel/\n",
    "tar czvf mymodel.tar.gz mymodel/\n",
    "rm -rf mymodel"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5a817d0e",
   "metadata": {},
   "source": [
    "### Step 3: Start building SageMaker endpoint"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "42b20125",
   "metadata": {},
   "source": [
    "#### Getting the container image URI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "e61a908b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "image_uri = image_uris.retrieve(\n",
    "        framework=\"djl-deepspeed\",\n",
    "        region=sess.boto_session.region_name,\n",
    "        version=\"0.27.0\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "96a08056",
   "metadata": {},
   "source": [
    "#### Upload artifact on S3 and create SageMaker model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f903af58",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "model_name = \"TheBloke/Nous-Hermes-2-Mixtral-8x7B-DPO-AWQ\"\n",
    "s3_code_prefix = f\"large-model-vllm/{model_name}code\"\n",
    "bucket = sess.default_bucket()  # bucket to house artifacts\n",
    "code_artifact = sess.upload_data(\"mymodel.tar.gz\", bucket, s3_code_prefix)\n",
    "print(f\"S3 Code or Model tar ball uploaded to --- > {code_artifact}\")\n",
    "\n",
    "model = Model(image_uri=image_uri, model_data=code_artifact, role=role)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c04c778f",
   "metadata": {},
   "source": [
    "#### Create SageMaker endpoint with a specified instance type"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f27a8df",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "instance_type = \"ml.g4dn.12xlarge\"\n",
    "endpoint_name = sagemaker.utils.name_from_base(f\"lmi-model-{model_name.replace('/', '-')}\")\n",
    "print(f\"endpoint_name: {endpoint_name}\")\n",
    "\n",
    "model.deploy(initial_instance_count=1,\n",
    "             instance_type=instance_type,\n",
    "             endpoint_name=endpoint_name,\n",
    "             container_startup_health_check_timeout=1800\n",
    "            )\n",
    "\n",
    "# our requests and responses will be in json format so we specify the serializer and the deserializer\n",
    "predictor = sagemaker.Predictor(\n",
    "    endpoint_name=endpoint_name,\n",
    "    sagemaker_session=sess,\n",
    "    serializer=serializers.JSONSerializer(),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "19004557",
   "metadata": {},
   "source": [
    "### Step 4: Run inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "fd192a22-3cda-4032-bdfd-0491e88cc068",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "system_message=\"\"\n",
    "input_text = \"请解释一下AI\"\n",
    "\n",
    "prompt_template=f'''<|im_start|>system\n",
    "{system_message}<|im_end|>\n",
    "<|im_start|>user\n",
    "{input_text}<|im_end|>\n",
    "<|im_start|>assistant\n",
    "'''"
   ]
  },
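  {
   "cell_type": "markdown",
   "id": "a3f1c2d4-chatml-note",
   "metadata": {},
   "source": [
    "The prompt above follows the ChatML layout (`<|im_start|>role ... <|im_end|>`). As a minimal sketch, the same layout can be wrapped in a helper so that several prompts share one template. The function name `build_chatml_prompt` is hypothetical and not part of any library."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3f1c2d4-chatml-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_chatml_prompt(user_message: str, system_message: str = '') -> str:\n",
    "    # Hypothetical helper: format one system turn and one user turn as ChatML.\n",
    "    return (\n",
    "        f'<|im_start|>system\\n{system_message}<|im_end|>\\n'\n",
    "        f'<|im_start|>user\\n{user_message}<|im_end|>\\n'\n",
    "        '<|im_start|>assistant\\n'\n",
    "    )\n",
    "\n",
    "\n",
    "# For the same inputs this yields the same string as prompt_template.\n",
    "example_prompt = build_chatml_prompt('Please explain AI')\n",
    "print(example_prompt)"
   ]
  },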
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "e426ba55-6cc3-4772-be77-0f99f29128e5",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "parameters = {\n",
    "                \"max_new_tokens\":128,\n",
    "                \"do_sample\":True,\n",
    "                \"temperature\":0.7,\n",
    "                \"top_p\":0.95,\n",
    "                \"top_k\":40,\n",
    "                \"repetition_penalty\":1.1\n",
    "            }"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a450b46-8783-41b3-b2f4-1b1bb05e8613",
   "metadata": {
    "tags": []
   },
   "source": [
    "## None Streaming"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "ccbf0046",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 10.5 ms, sys: 4.19 ms, total: 14.7 ms\n",
      "Wall time: 6.42 s\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'{\"generated_text\": \"AI，全称为人工智能（Artificial Intelligence），是指计算机系统通过学习、自我优化和模拟人类思维方式来执行复杂任务的技术。AI可以从数据中学习并生成预测或决策，以实现各种应用场景，例如图像识别、语音合成、自然语言处理等。AI的主要目标是使计算机系统能够进行类似于人类智能的行动，包括\"}'"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%time\n",
    "response = predictor.predict(\n",
    "    {\n",
    "        \"inputs\": prompt_template, \n",
    "         \"parameters\": parameters\n",
    "    }\n",
    ")\n",
    "text = str(response, 'utf-8')\n",
    "text"
   ]
  },
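  {
   "cell_type": "markdown",
   "id": "b7e9d0a1-parse-note",
   "metadata": {},
   "source": [
    "The endpoint returns a JSON document with a `generated_text` field, as the output above shows. The following is a minimal sketch of extracting that field. `sample_body` is a shortened, made-up stand-in for the bytes returned by `predictor.predict`, so the cell runs without a live endpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7e9d0a1-parse-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Illustrative stand-in for the raw bytes returned by predictor.predict(...)\n",
    "sample_body = b'{\"generated_text\": \"AI stands for Artificial Intelligence.\"}'\n",
    "\n",
    "# Decode the bytes, parse the JSON document, and pick out the generated text\n",
    "generated = json.loads(sample_body.decode('utf-8'))['generated_text']\n",
    "print(generated)"
   ]
  },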
  {
   "cell_type": "markdown",
   "id": "93c7d0b5-5ddb-40b3-848d-ce058b1a4b9e",
   "metadata": {},
   "source": [
    "## Streaming"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "62229879-28d2-4ecc-b0cf-c366317e7823",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import json\n",
    "import boto3\n",
    "\n",
    "smr_client = boto3.client(\"sagemaker-runtime\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "533c5232-3764-47fe-b8fb-cb95b4a16d71",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def get_realtime_response_stream(sagemaker_runtime, endpoint_name, payload):\n",
    "    response_stream = sagemaker_runtime.invoke_endpoint_with_response_stream(\n",
    "        EndpointName=endpoint_name,\n",
    "        Body=json.dumps(payload), \n",
    "        ContentType=\"application/json\",\n",
    "        CustomAttributes='accept_eula=false'\n",
    "    )\n",
    "    return response_stream"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "30f4d2e2-fa5b-4e13-925d-6c8bd41958ef",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "payload = {\n",
    "    \"inputs\":  prompt_template,\n",
    "    \"parameters\": parameters,\n",
    "    \"stream\": True ## <-- to have response stream.\n",
    "}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "e289afdd-a048-40a1-be33-8a6628d6dc05",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from utils.LineIterator import LineIterator\n",
    "\n",
    "def print_response_stream(response_stream):\n",
    "    event_stream = response_stream.get('Body')\n",
    "    for line in LineIterator(event_stream):\n",
    "        print(line, end='')"
   ]
  },
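  {
   "cell_type": "markdown",
   "id": "c5d2e8f0-lineiter-note",
   "metadata": {},
   "source": [
    "The `utils.LineIterator` helper is not shown in this notebook. The following is only a sketch of how such an iterator could work, not the actual helper. It assumes each event in the stream is a dict of the form `{'PayloadPart': {'Bytes': ...}}`, which is the shape returned by `invoke_endpoint_with_response_stream`. The class name `SketchLineIterator` and the fake events are hypothetical, so the cell runs without a live endpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5d2e8f0-lineiter-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import io\n",
    "\n",
    "\n",
    "class SketchLineIterator:\n",
    "    # Hypothetical stand-in for utils.LineIterator: buffers streamed chunks\n",
    "    # and yields one complete, newline-terminated line at a time.\n",
    "\n",
    "    def __init__(self, event_stream):\n",
    "        self.events = iter(event_stream)\n",
    "        self.buffer = io.BytesIO()\n",
    "        self.read_pos = 0\n",
    "\n",
    "    def __iter__(self):\n",
    "        return self\n",
    "\n",
    "    def __next__(self):\n",
    "        while True:\n",
    "            # Try to read a full line from the unread part of the buffer\n",
    "            self.buffer.seek(self.read_pos)\n",
    "            line = self.buffer.readline()\n",
    "            if line.endswith(b'\\n'):\n",
    "                self.read_pos += len(line)\n",
    "                return line.decode('utf-8')\n",
    "            try:\n",
    "                event = next(self.events)\n",
    "            except StopIteration:\n",
    "                # Stream ended: flush any trailing partial line, then stop\n",
    "                if line:\n",
    "                    self.read_pos += len(line)\n",
    "                    return line.decode('utf-8')\n",
    "                raise\n",
    "            if 'PayloadPart' in event:\n",
    "                # Append the new bytes at the end of the buffer\n",
    "                self.buffer.seek(0, io.SEEK_END)\n",
    "                self.buffer.write(event['PayloadPart']['Bytes'])\n",
    "\n",
    "\n",
    "# Fake events that split one line across two chunks\n",
    "fake_events = [\n",
    "    {'PayloadPart': {'Bytes': b'hello '}},\n",
    "    {'PayloadPart': {'Bytes': b'world\\nbye'}},\n",
    "]\n",
    "print(list(SketchLineIterator(fake_events)))"
   ]
  },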
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "335da0ec-7f78-4815-981c-2294ad466e69",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AI，全称为人工智能（Artificial Intelligence），是一个广泛的研究领域和技术实践。它涉及在计算机系统中模拟、扩展和创造智能行为和决策能力。\n",
      "\n",
      "AI的目标是使计算机系统能够像人类一样学习、理解、推理和进行决策。这意味着AI可以处理复杂的任务、自动化过程、提高效率和降低成本"
     ]
    }
   ],
   "source": [
    "response_stream = get_realtime_response_stream(smr_client, endpoint_name, payload)\n",
    "print_response_stream(response_stream)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5ad8352a-7eec-461d-b161-765a8a7220ba",
   "metadata": {},
   "source": [
    "## Clear resources"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "e9603928-b375-4d87-8c29-63b3cd59306a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "sess.delete_endpoint(endpoint_name)\n",
    "sess.delete_endpoint_config(endpoint_name)\n",
    "model.delete_model()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_pytorch_p310",
   "language": "python",
   "name": "conda_pytorch_p310"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
